﻿""" Miguel Nacenta 2022
This module uses the psy package by Joseph J. Glavan to implement the Psi adaptive staircase procedure of Kontsevich & Tyler (1999).
"""

import psy as ps
import numpy as np
import warnings as wr

class StaircaseProcedure:

    """ This class controls the staircase for one point of the scale in one direction.
        There can be as many of this object as necessary. Normally there will be one StaircaseProcedure object
        for each simultaneous point and direction being tested.

        All the adaptive work is delegated to a psy.PsiObject that operates in a normalised space (0 to 1);
        this class converts between that space and the parameter space given by parameterRange.
    """

    # Constants
    NUMBER_OF_BETA_DIVISIONS = 30.0 # how much precision we have to determine the beta (slope)
    MAXIMUM_BETA = 0.25 # proportion of the parameter range that is tested
    DELTA = 0.05 # The proportion of wrong answers even when the parameter differences are maxed out

    def normalize(self,value): # turns numbers in parameter space into normalised space (from 0 to 1)

        """ Maps a value (scalar or numpy array, element-wise) from parameter space into normalised space.

            For direction == 1 the mapping is mirrored: parameterRange[1] -> 0 and parameterRange[0] -> 1.
            For any other direction: parameterRange[0] -> 0 and parameterRange[1] -> 1.
        """

        if (self.direction == 1):
            return (1.0-(value-self.offset)*(self.scale))
        else:
            return (value-self.offset)*(self.scale)

    def denormalize(self,value): # turns numbers in normalised space into parameter space (specified by the parameter range)

        """ Inverse of normalize(): maps a value from normalised space back into parameter space,
            honouring the same direction-dependent mirroring.
        """

        if (self.direction == 1):
            return (1-value)*(1.0/self.scale)+self.offset
        else:
            return value*(1.0/self.scale)+self.offset

    def find_nearest(self,array,value):

        """ Returns [index, element] of the element of the numpy array `array` closest to `value`. """

        idx = (np.abs(array-value)).argmin()
        return [idx, array[idx]]

    def _alphaLowerIndex(self,indexOfParameterValue):

        """ Returns the index into possibleParameterValuesNorm of the smallest grid value that is not below
            the tested point (self.parameterPoint). That value is used as the lower bound of alpha.

            indexOfParameterValue - index of the grid value closest to self.parameterPoint.

            Raises ValueError when no such grid value exists, i.e. the tested point lies beyond the last
            grid value in this direction (previously this surfaced as an unexplained IndexError).
        """

        closest = self.possibleParameterValuesNorm[indexOfParameterValue]
        # normalize() introduces float noise (e.g. 0.9999999999999999 instead of 1.0). Values within a tiny
        # fraction of one grid step are treated as equal so the limit is not pushed one step too far.
        tolerance = 1e-9 * self.normPrecision
        if closest >= self.parameterPoint - tolerance:
            ixAlphaLimit = indexOfParameterValue
        else:
            ixAlphaLimit = indexOfParameterValue+1

        if ixAlphaLimit >= len(self.possibleParameterValuesNorm):
            raise ValueError("The parameter point: "+str(self.originalParameterPoint)+
                             " leaves no parameter values to test in direction "+str(self.originalDirection)+
                             "; it might be being approached from the wrong end.")
        return ixAlphaLimit

    def __init__(self,parameterRange,parameterPrecision,parameterPoint,direction):

        """
            Constructor of this class

            parameterRange - a list that contains the lowest value and the highest value of this parameter's space, in that order.

            parameterPrecision - the smallest granularity of change of this parameter. Note: it might be necessary
                                 to restrict the natural granularity of some of the parameters because too many will
                                 make the algorithm too slow. I'm sure it works OK with about 70.

            parameterPoint - The proportion of the parameterRange where the parameterPoint is being tested. E.g., if we are testing 5
                             equidistant values in the parameter space, the parameter points would be 0.0, 0.25, 0.5, 0.75 and 1.0. Note
                             that you need a different StaircaseProcedure object for each parameter point being tested.

            direction - Whether the point is being tested from above or below. + 1 means point
                        approached from below, -1 means point approached from above
                        (any value other than 1 is treated like -1).

            Raises:
                ValueError - if parameterRange is not [lowest, highest] with lowest < highest, if parameterPrecision
                             is not positive, or if the parameter point leaves no values to test in this direction.

            Notes:
                - This class can proceed indefinitely. The number of trials has to be controlled
                  from the code controlling this class.
                - A warning is issued when only a single value is left for alpha (the point is at the
                  far end of the range for this direction).
        """

        # Reject degenerate configurations up front; they would otherwise fail later with an
        # unhelpful ZeroDivisionError or silently produce a meaningless grid.
        if not float(parameterRange[1]) > float(parameterRange[0]):
            raise ValueError("parameterRange must be [lowest, highest] with lowest < highest, got: "+str(parameterRange))
        if not float(parameterPrecision) > 0:
            raise ValueError("parameterPrecision must be positive, got: "+str(parameterPrecision))

        # Fields that store the creation parameters of the object (to simplify programming)
        self.originalParameterPoint = parameterPoint
        self.originalDirection = direction

        self.parameterRange = parameterRange
        self.parameterPrecision = parameterPrecision
        self.direction = direction
        self.parameterPoint = parameterPoint
        if (self.direction == 1):
            # normalised space is mirrored for this direction (see normalize()), so mirror the point too
            self.parameterPoint = 1.0-parameterPoint

        self.possibleParameterValues = np.linspace(parameterRange[0],parameterRange[1],round((parameterRange[1]-parameterRange[0])/parameterPrecision)+1,True)
        self.parameterSize = len(self.possibleParameterValues)

        if (self.parameterSize > 150):
            wr.warn("The parameter's size is a bit large. Consider reducing it.")

        # for conversion between unnormalised and normalised values
        self.offset = float(parameterRange[0]) # subtract to normalise
        self.scale = 1.0/(float(parameterRange[1])-float(parameterRange[0])) # scale to normalise
        self.numDivisions = (float(parameterRange[1]) - float(parameterRange[0]))/float(parameterPrecision)
        self.normPrecision = float(parameterPrecision)*self.scale

        self.possibleParameterValuesNorm = self.normalize(self.possibleParameterValues)
        if (direction == 1):
            # normalize() is decreasing for direction 1; reverse so the grid is ascending in both directions
            self.possibleParameterValuesNorm = self.possibleParameterValuesNorm[::-1]

        # for the algorithm and use of Psi
        self.intensityRange = [0.0,1.0] # the range of possible values of the intensity (normalised)
        self.alphaRange = [self.intensityRange[0],self.intensityRange[1]] # we start with the full range, modify it below

        # calculate the index and the closest value to the suggested parameter value
        [indexOfParameterValue,self.closestParameterValue] = self.find_nearest(self.possibleParameterValuesNorm,self.parameterPoint)

        ixAlphaLimit = self._alphaLowerIndex(indexOfParameterValue)
        if ixAlphaLimit == self.parameterSize-1:
            # only the last grid value remains for alpha: nothing meaningful can be estimated
            wr.warn("There is something wrong on the value. The parameter point: "+str(parameterPoint)+" might be being approached from the wrong end.")

        # the alpha range goes from the first grid value not below the tested value up to 1
        self.alphaRange[0] = self.possibleParameterValuesNorm[ixAlphaLimit]

        self.alphaPrecision = self.normPrecision
        self.betaPrecision = self.MAXIMUM_BETA/self.NUMBER_OF_BETA_DIVISIONS
        self.betaRange = [self.betaPrecision,self.MAXIMUM_BETA]
        self.delta = self.DELTA
        self.stepType = "lin"

        # This is the big object that does the smart stuff (notice, precision for alpha is same as precision for parameter)
        self.myPsy = ps.PsiObject(self.intensityRange, self.alphaRange, self.betaRange,
                                  self.normPrecision, self.normPrecision, self.betaPrecision,
                                  delta = self.delta, stepType = self.stepType,
                                  TwoAFC = True, prior = None)

        self.myPsy.update() # initialisation (update needs to be called once)

    def getNextParameterValue(self):

        """ Returns the parameter value (in parameter space) to show in the next trial,
            i.e. the PsiObject's next intensity converted with denormalize().
        """
        nextIntensity = self.myPsy.nextIntensity

        return self.denormalize(nextIntensity)

    def getAssumedTestPoint(self):

        """ Gets the value in parameter space of the point that the algorithm considers the reference value.
            This is because the parameterPoint passed to the object might not coincide with an actual value,
            given the precision.
        """

        return self.denormalize(self.closestParameterValue)

    def update(self,response = None):

        """ Provides the next response. 0 means the trial was failed, 1 means the trial was successful (the subject saw the right difference)
        """
        self.myPsy.update(response)

    def getEstimatedNormalisedLambda(self):

        """ Returns the current estimated lambda, with alpha from 0 to 1 and beta from 0 to 1
        """
        return self.myPsy.estimateLambda()

    def getActualLambda(self):

        """ Returns the current estimated lambda converted to parameter space, as [alpha, beta]:
            alpha is passed through denormalize(), and beta is divided by self.scale
            (i.e. multiplied by the width of parameterRange).
            Assumes estimateLambda() returns (alpha, beta) in normalised units -- TODO confirm against psy.
        """
        estimatedLambda = self.myPsy.estimateLambda()
        return [self.denormalize(estimatedLambda[0]),estimatedLambda[1]/self.scale]

    def getProbabilitiesOfLambda(self):

        """ Returns [X, Y, Z] for plotting the posterior over (alpha, beta).

            X and Y come from np.meshgrid(betaGrid, alphaGrid): X holds beta values and Y holds alpha values.
            The grids are rebuilt here from alphaRange/betaRange and their precisions; they are assumed to
            match the grids PsiObject built internally (TODO confirm).
            Z is myPsy._probLambda[0,:,:,0] (a private psy attribute; presumably the alpha x beta posterior -- verify).
        """

        x = np.linspace(self.alphaRange[0], self.alphaRange[1], round((self.alphaRange[1]-self.alphaRange[0])/self.alphaPrecision)+1, True)
        y = np.linspace(self.betaRange[0], self.betaRange[1], round((self.betaRange[1]-self.betaRange[0])/self.betaPrecision)+1, True)

        X, Y = np.meshgrid(y,x)
        Z = self.myPsy._probLambda[0,:,:,0]

        return [X,Y,Z]
 